/*
 * Wazuh Vulnerability scanner - Database Feed Manager
 * Copyright (C) 2015, Wazuh Inc.
 * Oct 3, 2023.
 *
 * This program is free software; you can redistribute it
 * and/or modify it under the terms of the GNU General Public
 * License (version 2) as published by the FSF - Free Software
 * Foundation.
 */

#ifndef _UPDATE_CVE_DESCRIPTION_HPP
#define _UPDATE_CVE_DESCRIPTION_HPP

#include "cve5_generated.h"
#include "rocksDBWrapper.hpp"
#include "vulnerabilityDescription_generated.h"
#include "vulnerabilityScanner.hpp"
#include <variant>

/// Default ADP sub-short-name ("nvd"). Not referenced in this header; presumably used by includers as a fallback
/// provider name — confirm against callers. `inline` (C++17) makes this a single object shared by every
/// translation unit that includes the header, instead of one internal-linkage copy per TU.
inline constexpr auto DEFAULT_ADP_SUBSHORT_NAME {"nvd"};

/**
 * @brief UpdateCVEDescription class.
 *
 */
class UpdateCVEDescription final
{
public:
    /**
     * @brief Reads CVE5 database, creates a vulnerability description flatbuffer and stores it in a specific RocksDB
     * database.
     *
     * @param cve5Flatbuffer CVE5 Flatbuffer.
     * @param feedDatabase rocksDB wrapper instance.
     */
    static void storeVulnerabilityDescription(const cve_v5::Entry* cve5Flatbuffer,
                                              std::shared_ptr<Utils::RocksDBWrapper> feedDatabase)
    {
        auto storeDescriptionLambda = [&](std::variant<const cve_v5::Cna*, const cve_v5::Adp*> container)
        {
            const flatbuffers::Vector<flatbuffers::Offset<cve_v5::Metric>>* metricsArray = nullptr;
            const flatbuffers::Vector<flatbuffers::Offset<cve_v5::Description>>* descriptionArray = nullptr;
            const flatbuffers::Vector<flatbuffers::Offset<cve_v5::Reference>>* referencesArray = nullptr;
            const flatbuffers::Vector<flatbuffers::Offset<cve_v5::ProblemType>>* problemTypesArray = nullptr;
            std::string subShortName;

            std::visit(
                [&](auto&& arg)
                {
                    auto&& forwardedArg = std::forward<decltype(arg)>(arg);

                    metricsArray = forwardedArg->metrics();
                    descriptionArray = forwardedArg->descriptions();
                    referencesArray = forwardedArg->references();
                    problemTypesArray = forwardedArg->problemTypes();
                    subShortName = forwardedArg->providerMetadata()->x_subShortName()
                                       ? forwardedArg->providerMetadata()->x_subShortName()->str()
                                       : forwardedArg->providerMetadata()->shortName()->str();
                },
                container);

            // Required fields to create the vulnerability description.
            float vulnDescFBScoreBase = 0.0;
            std::string vulnDescFBClassificationStr;
            std::string vulnDescFBDescriptionStr;
            std::string vulnDescFBSeverityStr;
            std::string vulnDescFBScoreVersionStr;
            std::string vulnDescFBReferenceStr;
            std::string vulnDescFBAssignerStr;
            std::string vulnDescFBCWEIdStr;
            std::string vulnDescFBDataPublishedStr;
            std::string vulnDescFBDataUpdatedStr;
            std::string vulnDescFBAccessComplexityStr;
            std::string vulnDescFBAttackVectorStr;
            std::string vulnDescFBAuthenticationStr;
            std::string vulnDescFBAvailabilityStr;
            std::string vulnDescFBConfidentialityImpactStr;
            std::string vulnDescFBIntegrityImpactStr;
            std::string vulnDescFBPrivilegesRequiredStr;
            std::string vulnDescFBScopeStr;
            std::string vulnDescFBUserInteractionStr;

            // Extract the description if it exists.
            if (descriptionArray)
            {
                for (const auto& field : *descriptionArray)
                {
                    // We are only interested in the English description.
                    if (field->lang()->str().compare("en") == 0)
                    {
                        vulnDescFBDescriptionStr = field->value()->str();
                        break;
                    }
                }
            }
            // Empty description is not CVE5 Compliant.
            if (vulnDescFBDescriptionStr.empty())
            {
                return;
            }

            // Extract the references if they exist.
            if (referencesArray)
            {
                for (const auto& field : *referencesArray)
                {
                    if (field->url())
                    {
                        vulnDescFBReferenceStr += field->url()->str();
                        vulnDescFBReferenceStr += ", ";
                    }
                }
                // Remove the last comma and space.
                vulnDescFBReferenceStr = vulnDescFBReferenceStr.substr(0, vulnDescFBReferenceStr.size() - 2);
            }
            // Empty reference is not CVE5 Compliant.
            if (vulnDescFBReferenceStr.empty())
            {
                return;
            }

            // Extract the CVSS score and severity if they exist.
            if (metricsArray)
            {
                // We only extract one CVSS metric type. Priority is V3.1, then V3.0, and finally V2.0.
                cve_v5::Metric::FlatBuffersVTableOffset metricCVSSVersion;
                for (const auto& field : *metricsArray)
                {
                    auto metricCVSSV3_1 = field->cvssV3_1();
                    if (metricCVSSV3_1)
                    {
                        metricCVSSVersion = field->VT_CVSSV3_1;
                        vulnDescFBScoreBase = metricCVSSV3_1->baseScore();
                        vulnDescFBSeverityStr =
                            (metricCVSSV3_1->baseSeverity()) ? metricCVSSV3_1->baseSeverity()->str() : "";
                        vulnDescFBScoreVersionStr = (metricCVSSV3_1->version()) ? metricCVSSV3_1->version()->str() : "";
                        vulnDescFBClassificationStr = (field->format()) ? field->format()->str() : "";
                        vulnDescFBAttackVectorStr =
                            (metricCVSSV3_1->attackVector()) ? metricCVSSV3_1->attackVector()->str() : "";
                        vulnDescFBAvailabilityStr =
                            (metricCVSSV3_1->availabilityImpact()) ? metricCVSSV3_1->availabilityImpact()->str() : "";
                        vulnDescFBConfidentialityImpactStr = (metricCVSSV3_1->confidentialityImpact())
                                                                 ? metricCVSSV3_1->confidentialityImpact()->str()
                                                                 : "";
                        vulnDescFBIntegrityImpactStr =
                            (metricCVSSV3_1->integrityImpact()) ? metricCVSSV3_1->integrityImpact()->str() : "";
                        vulnDescFBPrivilegesRequiredStr =
                            (metricCVSSV3_1->privilegesRequired()) ? metricCVSSV3_1->privilegesRequired()->str() : "";
                        vulnDescFBScopeStr = (metricCVSSV3_1->scope()) ? metricCVSSV3_1->scope()->str() : "";
                        vulnDescFBUserInteractionStr =
                            (metricCVSSV3_1->userInteraction()) ? metricCVSSV3_1->userInteraction()->str() : "";
                        // If the higher version is found we do not need to continue the loop.
                        break;
                    }

                    auto metricCVSSV3_0 = field->cvssV3_0();
                    if (metricCVSSV3_0 && metricCVSSVersion != field->VT_CVSSV3_1)
                    {
                        metricCVSSVersion = field->VT_CVSSV3_0;
                        vulnDescFBScoreBase = metricCVSSV3_0->baseScore();
                        vulnDescFBSeverityStr =
                            (metricCVSSV3_0->baseSeverity()) ? metricCVSSV3_0->baseSeverity()->str() : "";
                        vulnDescFBScoreVersionStr = (metricCVSSV3_0->version()) ? metricCVSSV3_0->version()->str() : "";
                        vulnDescFBClassificationStr = (field->format()) ? field->format()->str() : "";
                        vulnDescFBAttackVectorStr =
                            (metricCVSSV3_0->attackVector()) ? metricCVSSV3_0->attackVector()->str() : "";
                        vulnDescFBAvailabilityStr =
                            (metricCVSSV3_0->availabilityImpact()) ? metricCVSSV3_0->availabilityImpact()->str() : "";
                        vulnDescFBConfidentialityImpactStr = (metricCVSSV3_0->confidentialityImpact())
                                                                 ? metricCVSSV3_0->confidentialityImpact()->str()
                                                                 : "";
                        vulnDescFBIntegrityImpactStr =
                            (metricCVSSV3_0->integrityImpact()) ? metricCVSSV3_0->integrityImpact()->str() : "";
                        vulnDescFBPrivilegesRequiredStr =
                            (metricCVSSV3_0->privilegesRequired()) ? metricCVSSV3_0->privilegesRequired()->str() : "";
                        vulnDescFBScopeStr = (metricCVSSV3_0->scope()) ? metricCVSSV3_0->scope()->str() : "";
                        vulnDescFBUserInteractionStr =
                            (metricCVSSV3_0->userInteraction()) ? metricCVSSV3_0->userInteraction()->str() : "";
                    }

                    auto metricCVSSV2_0 = field->cvssV2_0();
                    if (metricCVSSV2_0 && metricCVSSVersion != field->VT_CVSSV3_1 &&
                        metricCVSSVersion != field->VT_CVSSV3_0)
                    {
                        vulnDescFBScoreBase = metricCVSSV2_0->baseScore();
                        vulnDescFBSeverityStr = (vulnDescFBScoreBase < 4.0)   ? "LOW"
                                                : (vulnDescFBScoreBase < 7.0) ? "MEDIUM"
                                                                              : "HIGH";
                        vulnDescFBScoreVersionStr = (metricCVSSV2_0->version()) ? metricCVSSV2_0->version()->str() : "";
                        vulnDescFBClassificationStr = (field->format()) ? field->format()->str() : "";
                        vulnDescFBAccessComplexityStr =
                            (metricCVSSV2_0->accessComplexity()) ? metricCVSSV2_0->accessComplexity()->str() : "";
                        vulnDescFBAuthenticationStr =
                            (metricCVSSV2_0->authentication()) ? metricCVSSV2_0->authentication()->str() : "";
                        vulnDescFBAvailabilityStr =
                            (metricCVSSV2_0->availabilityImpact()) ? metricCVSSV2_0->availabilityImpact()->str() : "";
                        vulnDescFBConfidentialityImpactStr = (metricCVSSV2_0->confidentialityImpact())
                                                                 ? metricCVSSV2_0->confidentialityImpact()->str()
                                                                 : "";
                        vulnDescFBIntegrityImpactStr =
                            (metricCVSSV2_0->integrityImpact()) ? metricCVSSV2_0->integrityImpact()->str() : "";
                    }
                }
            }

            // Extract the metadata if it exists.
            if (const auto cve5Metadata = cve5Flatbuffer->cveMetadata(); cve5Metadata)
            {
                vulnDescFBAssignerStr =
                    (cve5Metadata->assignerShortName()) ? cve5Metadata->assignerShortName()->str() : "";
                vulnDescFBDataPublishedStr =
                    (cve5Metadata->datePublished()) ? cve5Metadata->datePublished()->str() : "";
                vulnDescFBDataUpdatedStr = (cve5Metadata->dateUpdated()) ? cve5Metadata->dateUpdated()->str() : "";
            }

            // Extract the problem types if they exist.
            if (problemTypesArray)
            {
                auto problemTypesDescriptionsArray = problemTypesArray->Get(0);
                if (problemTypesDescriptionsArray)
                {
                    auto descriptionsProblemTypesArray = problemTypesDescriptionsArray->descriptions();
                    if (descriptionsProblemTypesArray)
                    {
                        vulnDescFBCWEIdStr = (descriptionsProblemTypesArray->Get(0)->cweId())
                                                 ? descriptionsProblemTypesArray->Get(0)->cweId()->str()
                                                 : "";
                    }
                }
            }

            // Build the vulnerability description flatbuffer.
            flatbuffers::FlatBufferBuilder builder;

            auto vulnerabilityDescriptionFB =
                NSVulnerabilityScanner::CreateVulnerabilityDescriptionDirect(builder,
                                                                             vulnDescFBAccessComplexityStr.c_str(),
                                                                             vulnDescFBAssignerStr.c_str(),
                                                                             vulnDescFBAttackVectorStr.c_str(),
                                                                             vulnDescFBAuthenticationStr.c_str(),
                                                                             vulnDescFBAvailabilityStr.c_str(),
                                                                             vulnDescFBClassificationStr.c_str(),
                                                                             vulnDescFBConfidentialityImpactStr.c_str(),
                                                                             vulnDescFBCWEIdStr.c_str(),
                                                                             vulnDescFBDataPublishedStr.c_str(),
                                                                             vulnDescFBDataUpdatedStr.c_str(),
                                                                             vulnDescFBDescriptionStr.c_str(),
                                                                             vulnDescFBIntegrityImpactStr.c_str(),
                                                                             vulnDescFBPrivilegesRequiredStr.c_str(),
                                                                             vulnDescFBReferenceStr.c_str(),
                                                                             vulnDescFBScopeStr.c_str(),
                                                                             vulnDescFBScoreBase,
                                                                             vulnDescFBScoreVersionStr.c_str(),
                                                                             vulnDescFBSeverityStr.c_str(),
                                                                             vulnDescFBUserInteractionStr.c_str());

            builder.Finish(vulnerabilityDescriptionFB);

            // Generate the key and column name where the description will be stored.
            const auto cveId {cve5Flatbuffer->cveMetadata()->cveId()->str()};
            const auto descriptionColumn = std::string(DESCRIPTIONS_COLUMN) + "_" + subShortName;

            if (!feedDatabase->columnExists(descriptionColumn))
            {
                feedDatabase->createColumn(descriptionColumn);
            }

            // Store the description in the database.
            const uint8_t* buffer = builder.GetBufferPointer();
            size_t flatbufferSize = builder.GetSize();
            const rocksdb::Slice VulnerabilityDescriptionSlice(reinterpret_cast<const char*>(buffer), flatbufferSize);

            feedDatabase->put(cveId, VulnerabilityDescriptionSlice, descriptionColumn);
        };

        // Store the description for the CNA
        if (cve5Flatbuffer->containers() && cve5Flatbuffer->containers()->cna())
        {
            storeDescriptionLambda(cve5Flatbuffer->containers()->cna());
        }

        // Store the description for the ADPs
        if (cve5Flatbuffer->containers() && cve5Flatbuffer->containers()->adp())
        {
            for (const auto& adp : *cve5Flatbuffer->containers()->adp())
            {
                storeDescriptionLambda(adp);
            }
        }
    }

    /**
     * @brief Deletes a vulnerability description from the database.
     *
     * @param data Flatbuffer object containing the CVE information.
     * @param feedDatabase rocksDB wrapper instance.
     */
    static void removeVulnerabilityDescription(const cve_v5::Entry* data,
                                               std::shared_ptr<Utils::RocksDBWrapper> feedDatabase)
    {
        if (!data->cveMetadata() || !data->cveMetadata()->cveId())
        {
            return;
        }

        const auto columns = feedDatabase->getAllColumns();

        for (const auto& column : columns)
        {
            // If the column start with the prefix DESCRIPTIONS_COLUMN, then it is a description column.
            if (column.find(DESCRIPTIONS_COLUMN) == 0)
            {
                std::string key {data->cveMetadata()->cveId()->str()};
                feedDatabase->delete_(key, column);
            }
        }
    }
};

#endif // _UPDATE_CVE_DESCRIPTION_HPP
